package com.tianji.aigc.agent;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
import com.tianji.aigc.config.ToolResultHolder;
import com.tianji.aigc.constants.Constant;
import com.tianji.aigc.enums.ChatEventTypeEnum;
import com.tianji.aigc.service.ChatService;
import com.tianji.aigc.service.ChatSessionService;
import com.tianji.aigc.vo.ChatEventVO;
import com.tianji.common.utils.UserContext;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Base {@link Agent} implementation that sends a question to the injected {@link ChatModel}
 * (bean {@code dashscopeChatModel}) through a {@link ChatClient}. The request is configured with the
 * system prompt, advisors, tools and tool context provided by {@link Agent} / subclass hooks.
 *
 * <p>Supports a blocking reply ({@link #process}) and a streaming reply ({@link #processStream}).
 * A streaming reply can be interrupted through {@link #stop(String)}. The partial output is then written
 * to chat memory.
 */
@Slf4j
public abstract class AbstractAgent implements Agent {

    @Resource
    private ChatSessionService chatSessionService;
    @Resource
    private ChatModel dashscopeChatModel;
    @Resource
    private ChatMemory chatMemory;
    @Resource
    private Advisor loggerAdvisor;

    /** Event emitted as the last element of every stream to mark the end of the output. */
    public static final ChatEventVO STOP_EVENT = ChatEventVO.builder().eventType(ChatEventTypeEnum.STOP.getValue()).build();


    /**
     * Generation status per session id. An entry is {@code true} while a stream for that session is
     * producing output; {@link #stop(String)} removes it to end the stream early. A ConcurrentHashMap is
     * used because the map is shared between concurrent requests.
     * <p>The state is in-memory only. A distributed deployment would need a shared store (e.g. Redis).
     */
    public static final Map<String, Boolean> GENERATE_STATUS = new ConcurrentHashMap<>();


    /**
     * Sends {@code question} to the model and blocks until the complete answer is available, then updates
     * the session title / update time.
     *
     * @param question  the user's question
     * @param sessionId the chat session id
     * @return the model's answer text
     */
    @Override
    public String process(String question, String sessionId) {
        var requestId = this.generateRequestId();

        // Id of the current user.
        var userId = UserContext.getUser();

        var result = this.getChatClientRequest(sessionId, requestId, question)
                .call()
                .content();

        this.updateTitleOrUpdateTime(question, sessionId, result, userId);
        return result;
    }

    /**
     * Streams the model's answer for {@code question} as {@link ChatEventVO}s.
     *
     * <p>The stream emits these events in order:
     * <ol>
     *   <li>DATA events carrying the generated text;</li>
     *   <li>a PARAM event with the parameters recorded in {@link ToolResultHolder} for this request, if any;</li>
     *   <li>{@link #STOP_EVENT}.</li>
     * </ol>
     * Calling {@link #stop(String)} for the same session ends the DATA part early. When the model stream is
     * cancelled, the text produced so far is saved to chat memory.
     *
     * @param question  the user's question
     * @param sessionId the chat session id
     * @return a {@link Flux} of chat events
     */
    public Flux<ChatEventVO> processStream(String question, String sessionId) {
        var requestId = this.generateRequestId();

        // Conversation id under which chat memory keeps this session's history.
        var conversationId = ChatService.getConversationId(sessionId);
        // Everything emitted so far. Kept so the output can be persisted if generation is interrupted.
        var outputBuilder = new StringBuilder();
        // Read the user id up front. The callbacks below run later, possibly on another thread
        // (assumes UserContext is bound to the calling thread — confirm).
        var userId = UserContext.getUser();

        return this.getChatClientRequest(sessionId, requestId, question)
                .stream()
                .chatResponse()
                // Mark the session as generating before the first response is requested.
                .doFirst(() -> GENERATE_STATUS.put(sessionId, true))
                // The model stream was cancelled (stop() or a downstream cancel): persist the partial output.
                .doOnCancel(() -> this.saveStopHistoryRecord(conversationId, outputBuilder.toString()))
                .doFinally(signalType -> {
                    // Clear the flag on every termination (complete, error AND cancel), so the entry does not
                    // leak when the subscriber goes away mid-stream. This runs first so that an exception in
                    // the update below cannot skip it.
                    GENERATE_STATUS.remove(sessionId);
                    // Update the session title or its last-update time.
                    this.updateTitleOrUpdateTime(question, sessionId, outputBuilder.toString(), userId);
                })
                // Keep streaming only while the session is still flagged. stop() removes the flag.
                .takeWhile(response -> GENERATE_STATUS.getOrDefault(sessionId, false))
                // getResult() is null for frames without generations. Such frames carry no text, so skip them.
                .filter(response -> response.getResult() != null)
                .map(response -> {
                    var generation = response.getResult();
                    // On the final frame, record this request id under the message id. Stored history can
                    // then map a message id -> request id -> tool parameters, which the history itself lacks.
                    var finishReason = generation.getMetadata().getFinishReason();
                    if (StrUtil.equals(Constant.STOP, finishReason)) {
                        var messageId = response.getMetadata().getId();
                        ToolResultHolder.put(messageId, Constant.REQUEST_ID, requestId);
                    }

                    var text = generation.getOutput().getText();
                    // A frame may have no text. Do not let StringBuilder append the literal "null".
                    if (text != null) {
                        outputBuilder.append(text);
                    }
                    return ChatEventVO.builder()
                            .eventData(text)
                            .eventType(ChatEventTypeEnum.DATA.getValue())
                            .build();
                })
                // After the DATA events: emit the parameters collected for this request (if any), then STOP.
                .concatWith(Flux.defer(() -> {
                    var map = ToolResultHolder.get(requestId);
                    if (CollUtil.isNotEmpty(map)) {
                        ToolResultHolder.remove(requestId); // deliver the parameters once, then drop them

                        var paramEvent = ChatEventVO.builder()
                                .eventData(map)
                                .eventType(ChatEventTypeEnum.PARAM.getValue())
                                .build();
                        return Flux.just(paramEvent, STOP_EVENT);
                    }
                    return Flux.just(STOP_EVENT);
                }));
    }

    /**
     * Passes a transcript of the question/answer pair to {@link ChatSessionService#update}, which updates
     * the session's title or its update time.
     *
     * @param question  the user's question
     * @param sessionId the chat session id
     * @param result    the (possibly partial) model answer
     * @param userId    the id of the user owning the session
     */
    private void updateTitleOrUpdateTime(String question, String sessionId, String result, Long userId) {
        var content = StrUtil.format("""
                             ------------
                            USER:{} \n
                            ASSISTANT:{}
                            ------------
                            """, question, result);
        this.chatSessionService.update(sessionId, content, userId);
    }


    /**
     * Saves the output produced before generation stopped as an assistant message in chat memory.
     *
     * @param conversationId the conversation id in chat memory
     * @param content        the model output produced so far
     */
    private void saveStopHistoryRecord(String conversationId, String content) {
        this.chatMemory.add(conversationId, new AssistantMessage(content));
    }

    /**
     * Builds a prompt request on a new {@link ChatClient}. The request is configured with:
     * <ul>
     *   <li>the agent's system message and parameters;</li>
     *   <li>the default and agent-specific advisors with their parameters;</li>
     *   <li>the agent's tools and tool context;</li>
     *   <li>the user question.</li>
     * </ul>
     */
    private ChatClient.ChatClientRequestSpec getChatClientRequest(String sessionId, String requestId, String question) {
        return ChatClient.builder(this.dashscopeChatModel).build()
                .prompt()
                .system(promptSystem -> promptSystem
                        .text(this.systemMessage())
                        .params(this.systemMessageParams())
                )
                .advisors(advisor -> advisor
                        .advisors(this.defaultAdvisors(question)) // advisors shared by all agents
                        .advisors(this.advisors(question)) // agent-specific advisors
                        .params(this.advisorParams(sessionId, requestId))
                )
                .tools(this.tools())
                .toolContext(this.toolContext(sessionId, requestId))
                .user(question);
    }

    /** Returns a new random request id (simple UUID, no dashes). */
    private String generateRequestId() {
        return IdUtil.fastSimpleUUID();
    }

    /** Advisors applied to every request: only the logger advisor. */
    @Override
    public List<Advisor> defaultAdvisors(String question) {
        return List.of(this.loggerAdvisor);
    }

    /** Advisor parameters: passes the session's conversation id to the chat-memory advisor. */
    @Override
    public Map<String, Object> advisorParams(String sessionId, String requestId) {
        String conversationId = ChatService.getConversationId(sessionId);
        return Map.of(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId);
    }

    /**
     * Requests that the running stream for {@code sessionId} stop. It clears the generation flag; the stream
     * ends when the next response reaches its {@code takeWhile} check.
     */
    @Override
    public void stop(String sessionId) {
        GENERATE_STATUS.remove(sessionId);
    }
}
